{
    "cells": [
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "To train this agent, click _Runtime_ and press _Run all_. Make sure you've enabled a free Tesla T4 GPU!\n",
                "\n",
                "<div class=\"align-center\">\n",
                "<a href=\"https://github.com/openpipe/art\"><img src=\"https://github.com/openpipe/art/raw/main/assets/ART_pill.png\" height=\"50\"></a>\n",
                "<a href=\"https://discord.gg/zbBHRUpwf4\"><img src=\"https://github.com/openpipe/art/raw/main/assets/Discord_pill.png\" height=\"50\"></a>\n",
                "<a href=\"https://art.openpipe.ai\"><img src=\"https://github.com/openpipe/art/raw/main/assets/Documentation_pill.png\" height=\"50\"></a>\n",
                "\n",
                "Questions? Join the Discord and ask away! For feature requests or to leave a star, visit our [GitHub](https://github.com/openpipe/art).\n",
                "\n",
                "</div>\n",
                "\n",
                "<a href=\"https://art.openpipe.ai/\"><img src=\"https://github.com/openpipe/art/raw/main/assets/Header_separator.png\" height=\"5\"></a>\n",
                "\n",
                "This notebook shows how to train a Qwen 2.5 7B model to play Temporal Clue, a simplified version of the game Clue that relies on temporal reasoning. It will demonstrate how to set up a single-turn agent, how to train it, and how to evaluate it.\n",
                "\n",
                "Completions will be logged to OpenPipe, and metrics will be logged to Weights & Biases.\n",
                "\n",
                "You will learn how to construct an [agentic environment](#Environment), how to define a [rollout](#Rollout), and how to run a [training loop](#Loop).\n",
                "\n",
                "SIDE NOTE: If you're curious about how Temporal Clue works and want to learn how we trained a model to beat o3-mini, check out the [blog post](https://openpipe.ai/blog/using-grpo-to-beat-o1-o3-mini-and-r1-on-temporal-clue)!\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Installation"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "%%capture\n",
                "!uv pip install openpipe-art==0.3.11.post2 \"gql<4\" --prerelease allow --no-cache-dir"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Environment Variables\n",
                "\n",
                "Later in the notebook, we'll create a model that automatically logs metrics to Weights & Biases. To enable that logging, you'll need to provide your Weights & Biases API key as an environment variable.\n",
                "\n",
                "You can also optionally initialize an OpenPipe client to report completions to a [dashboard](https://app.openpipe.ai), which lets you see what completions your model is generating and how they change over time. Logging to OpenPipe is free, but it is not required for training!\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import os\n",
                "\n",
                "# Optional\n",
                "WANDB_API_KEY = \"\"\n",
                "if WANDB_API_KEY:\n",
                "    os.environ[\"WANDB_API_KEY\"] = WANDB_API_KEY\n",
                "\n",
                "# Optional\n",
                "OPENPIPE_API_KEY = \"\"\n",
                "if OPENPIPE_API_KEY:\n",
                "    os.environ[\"OPENPIPE_API_KEY\"] = OPENPIPE_API_KEY"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Agentic Environment\n",
                "\n",
                "<a name=\"Environment\"></a>\n",
                "\n",
                "ART allows your agent to learn by interacting with its environment. In this example, we'll create an environment in which the agent can play Temporal Clue.\n",
                "\n",
                "Feel free to read as much or as little of this section's code as you'd like. The important thing to understand is that we're defining the rules of this agent's environment. In many cases, this will already be defined by the task you're trying to solve, but if you need to define a custom environment, this is how you do it."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import art\n",
                "import asyncio\n",
                "from dotenv import load_dotenv\n",
                "import json\n",
                "import requests\n",
                "import random\n",
                "import re\n",
                "from typing import TypedDict\n",
                "\n",
                "load_dotenv()\n",
                "\n",
                "\n",
                "class TemporalCluePuzzle(TypedDict):\n",
                "    num_clues: int\n",
                "    prompt: str\n",
                "    solution: dict[str, str]\n",
                "\n",
                "\n",
                "# download the puzzles from the github repo\n",
                "puzzles_url = \"https://raw.githubusercontent.com/openpipe/art/main/examples/data/temporal-clue/puzzles.json\"\n",
                "puzzles_response = requests.get(puzzles_url)\n",
                "\n",
                "puzzles: list[TemporalCluePuzzle] = json.loads(puzzles_response.text)\n",
                "val_puzzles = puzzles[:64]\n",
                "test_puzzles = puzzles[64:128]\n",
                "train_puzzles = puzzles[128:]\n",
                "random.seed(42)\n",
                "random.shuffle(train_puzzles)"
            ]
        },
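        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The slicing above carves the downloaded puzzle list into fixed splits: the first 64 puzzles for validation, the next 64 for testing, and the remainder (shuffled with a fixed seed) for training. A quick sanity check of the split arithmetic, with a dummy list standing in for the real puzzles:\n",
                "\n",
                "```python\n",
                "dummy_puzzles = list(range(200))  # stand-in for the downloaded puzzle list\n",
                "val, test, train = dummy_puzzles[:64], dummy_puzzles[64:128], dummy_puzzles[128:]\n",
                "assert (len(val), len(test), len(train)) == (64, 64, 72)\n",
                "```\n"
            ]
        },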
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Creating a Model\n",
                "\n",
                "Now that we've defined the rules of our environment, we can create a model that will learn to play Temporal Clue. We'll use a Qwen 2.5 7B model for this example. The `name` parameter will be associated with a wandb run, and the `base_model` parameter is the model that we'll be training a LoRA on top of.\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from art.local import LocalBackend\n",
                "\n",
                "model = art.TrainableModel(\n",
                "    name=\"001\",\n",
                "    project=\"temporal-clue\",\n",
                "    base_model=\"Qwen/Qwen2.5-7B-Instruct\",\n",
                "    _internal_config={\"init_args\": {\"gpu_memory_utilization\": 0.775}},\n",
                ")\n",
                "await model.register(LocalBackend(path=\"./.art\"))"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Defining a Rollout\n",
                "\n",
                "<a name=\"Rollout\"></a>\n",
                "\n",
                "A rollout is a single episode of an agent performing its task. It generates one or more trajectories, which are lists of messages and choices.\n",
                "\n",
                "In this example, the rollout function presents a Temporal Clue prompt, and the agent gives its best guess at the solution. The function then returns a trajectory that contains the `user` message presented to the agent, as well as the `choice` the agent made.\n",
                "\n",
                "When the game is finished, the `reward` for the agent's performance is calculated as the fraction of the agent's final answers that match the puzzle's solution.\n",
                "\n",
                "This rollout function will be called many times in parallel during each step of the training loop.\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "async def rollout(model: art.Model, puzzle: TemporalCluePuzzle) -> art.Trajectory:\n",
                "    messages: art.Messages = [{\"role\": \"user\", \"content\": puzzle[\"prompt\"]}]\n",
                "    client = model.openai_client()\n",
                "    chat_completion = await client.chat.completions.create(\n",
                "        messages=messages, model=model.name\n",
                "    )\n",
                "    choice = chat_completion.choices[0]\n",
                "    content = choice.message.content\n",
                "    assert isinstance(content, str)\n",
                "    num_correct = 0\n",
                "    for key, value in puzzle[\"solution\"].items():\n",
                "        if matches := re.findall(rf\"{key}\\. ([A-Za-z \\.:-]+)\", content):\n",
                "            match = matches[-1]\n",
                "            if match.strip().lower() == value.lower():\n",
                "                num_correct += 1\n",
                "    reward = acc = num_correct / len(puzzle[\"solution\"])\n",
                "    return art.Trajectory(\n",
                "        messages_and_choices=[*messages, choice], reward=reward, metrics={\"acc\": acc}\n",
                "    )"
            ]
        },
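        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "To make the scoring concrete, here is the same regex check run standalone. The solution keys, answers, and model output below are invented for illustration:\n",
                "\n",
                "```python\n",
                "import re\n",
                "\n",
                "solution = {\"A\": \"Miss Scarlet\", \"B\": \"The Library\"}  # hypothetical solution\n",
                "content = \"Reasoning...\\nA. Miss Scarlet\\nB. The Kitchen\\n\"  # hypothetical model output\n",
                "\n",
                "num_correct = 0\n",
                "for key, value in solution.items():\n",
                "    # take the last line of the form \"A. <answer>\" for each solution key\n",
                "    if matches := re.findall(rf\"{key}\\. ([A-Za-z \\.:-]+)\", content):\n",
                "        if matches[-1].strip().lower() == value.lower():\n",
                "            num_correct += 1\n",
                "\n",
                "reward = num_correct / len(solution)  # 0.5: one of the two answers is correct\n",
                "```\n"
            ]
        },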
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "<a name=\"Loop\"></a>\n",
                "\n",
                "### Training Loop\n",
                "\n",
                "The training loop is where the magic happens. For each of the 1000 steps defined below, the rollout function will be called 2 times in parallel on each of the 64 validation puzzles, and 50 times in parallel on groups of 4 train puzzles. This means that 128 validation and 200 train games will be played at once. Each game will produce a trajectory, which will be used to update the model or record metrics.\n",
                "\n",
                "The `gather_trajectory_groups` function waits for all of the trajectories to be generated. We then delete all but the most recent checkpoint and train the model on the generated `train` trajectories.\n",
                "\n",
                "Inference will be blocked until the training is complete."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "stride = 4\n",
                "for i in range(await model.get_step(), 1_000):\n",
                "    val_groups, train_groups = await asyncio.gather(\n",
                "        art.gather_trajectory_groups(\n",
                "            (\n",
                "                art.TrajectoryGroup(rollout(model, puzzle) for _ in range(2))\n",
                "                for puzzle in val_puzzles\n",
                "            ),\n",
                "            pbar_desc=\"val\",\n",
                "        ),\n",
                "        art.gather_trajectory_groups(\n",
                "            (\n",
                "                art.TrajectoryGroup(rollout(model, puzzle) for _ in range(50))\n",
                "                for puzzle in train_puzzles[i * stride : (i + 1) * stride]\n",
                "            ),\n",
                "            pbar_desc=\"train\",\n",
                "        ),\n",
                "    )\n",
                "    await model.log(val_groups)\n",
                "    await model.delete_checkpoints()\n",
                "    await model.train(\n",
                "        train_groups,\n",
                "        config=art.TrainConfig(learning_rate=5e-5),\n",
                "    )"
            ]
        },
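        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The `train_puzzles[i * stride : (i + 1) * stride]` slice above walks through the shuffled training puzzles four at a time, so each step trains on a fresh group of puzzles. A quick sketch with a dummy list standing in for the real puzzle list:\n",
                "\n",
                "```python\n",
                "stride = 4\n",
                "dummy_train_puzzles = list(range(12))  # stand-in for the shuffled training puzzles\n",
                "\n",
                "# the slices taken at steps 0, 1, and 2 of the loop\n",
                "batches = [dummy_train_puzzles[i * stride : (i + 1) * stride] for i in range(3)]\n",
                "# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]\n",
                "```\n"
            ]
        },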
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "<div class=\"align-center\">\n",
                "<a href=\"https://github.com/openpipe/art\"><img src=\"https://github.com/openpipe/art/raw/notebooks/assets/ART_pill.png\" height=\"50\"></a>\n",
                "<a href=\"https://discord.gg/zbBHRUpwf4\"><img src=\"https://github.com/openpipe/art/raw/notebooks/assets/Discord_pill.png\" height=\"50\"></a>\n",
                "<a href=\"https://openpipe.ai/blog/art-e-mail-agent\"><img src=\"https://github.com/openpipe/art/raw/main/assets/ART_E_pill.png\" height=\"50\"></a>\n",
                "\n",
                "Questions? Join the Discord and ask away! For feature requests or to leave a star, visit our [GitHub](https://github.com/openpipe/art).\n",
                "\n",
                "</div>\n"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": ".venv",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.10.13"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
